import os
import random
from collections import OrderedDict
from typing import Iterable, List

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad
from torch.optim import Optimizer

from lib.utils.logger import logger
from lib.utils.misc import CONST


def batch_to_device(batch, device):
    """Recursively move every tensor inside ``batch`` to ``device``.

    Dicts are rebuilt as plain dicts with the same keys; lists and tuples are
    both rebuilt as lists; any non-tensor leaf is returned untouched.

    Args:
        batch: a tensor, or an arbitrarily nested dict/list/tuple of values.
        device: target passed straight to ``Tensor.to``.

    Returns:
        The same structure with tensors moved to ``device``.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: batch_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return [batch_to_device(item, device) for item in batch]
    # Leaves such as ints, strings or None pass through as-is.
    return batch


def batch_to_cpu(batch):
    """Recursively convert every tensor inside ``batch`` to a CPU NumPy array.

    Tensors are detached from the autograd graph, copied to the CPU and turned
    into ``numpy.ndarray``. Dicts are rebuilt as plain dicts, lists and tuples
    are both rebuilt as lists, and any other value is returned unchanged.

    Args:
        batch: a tensor, or an arbitrarily nested dict/list/tuple of values.

    Returns:
        The same structure with tensors replaced by NumPy arrays.
    """
    if isinstance(batch, torch.Tensor):
        host_tensor = batch.detach().cpu()
        return host_tensor.numpy()
    if isinstance(batch, dict):
        return {key: batch_to_cpu(value) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return list(map(batch_to_cpu, batch))
    return batch


def freeze_batchnorm_stats(model):
    """Stop every BatchNorm layer in ``model`` from updating its running statistics.

    Sets ``momentum = 0`` on each ``_BatchNorm`` instance found in ``model``
    (including ``model`` itself). With a zero momentum, PyTorch's running-average
    update keeps the stored running mean/var unchanged, even in train mode.
    ``requires_grad`` of the affine weight/bias is not touched.

    Args:
        model (nn.Module): module modified in place.
    """
    # ``modules()`` already yields ``model`` and every descendant, so a single
    # pass is enough; recursing into children as well would only repeat work.
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.momentum = 0


def recurse_freeze(model):
    """Freeze ``model`` completely, in place.

    Every ``_BatchNorm`` layer in the tree gets ``momentum = 0``, so its running
    statistics no longer change. Every parameter of ``model`` (recursively) gets
    ``requires_grad = False``.

    Args:
        model (nn.Module): module modified in place.
    """
    # ``modules()`` walks the whole tree and ``parameters()`` is recursive, so one
    # pass of each reaches everything. The previous nested/recursive version
    # produced the same end state but revisited modules and parameters many times.
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.momentum = 0

    for param in model.parameters():
        param.requires_grad = False


def build_optimizer(params: Iterable, cfg) -> Optimizer:
    """Build a ``torch.optim`` optimizer from a config object.

    Args:
        params (Iterable): parameters (or parameter-group dicts) to optimize.
        cfg: config exposing ``OPTIMIZER`` and ``LR`` as attributes and a
            ``get(key, default)`` method for optional keys. ``OPTIMIZER`` is one
            of ``Adam``/``adam``, ``AdamW``/``adamw`` or ``SGD``/``sgd``.
            Optional keys are ``WEIGHT_DECAY`` (default 0.0, or 0.01 for AdamW)
            and ``MOMENTUM`` (SGD only, default 0.0).

    Returns:
        Optimizer: the configured optimizer.

    Raises:
        NotImplementedError: if ``cfg.OPTIMIZER`` is not a supported name.
    """
    name = cfg.OPTIMIZER
    # Numeric fields are coerced with float(). YAML 1.1 loaders such as PyYAML
    # read exponent literals without a dot (e.g. ``1e-3``) as strings.
    if name in ("Adam", "adam"):
        return torch.optim.Adam(
            params,
            lr=float(cfg.LR),
            weight_decay=float(cfg.get("WEIGHT_DECAY", 0.0)),
        )
    if name in ("AdamW", "adamw"):
        return torch.optim.AdamW(
            params,
            lr=float(cfg.LR),
            weight_decay=float(cfg.get("WEIGHT_DECAY", 0.01)),
        )
    if name in ("SGD", "sgd"):
        return torch.optim.SGD(
            params,
            lr=float(cfg.LR),
            momentum=float(cfg.get("MOMENTUM", 0.0)),
            weight_decay=float(cfg.get("WEIGHT_DECAY", 0.0)),
        )
    raise NotImplementedError(f"{cfg.OPTIMIZER} not be implemented yet")


def build_scheduler(optimizer: Optimizer, cfg):
    """Build a learning-rate scheduler for ``optimizer`` from a config object.

    Supported ``cfg.SCHEDULER`` values:
      * ``"StepLR"`` with a list/tuple ``LR_DECAY_STEP``: builds ``MultiStepLR``
        that decays at each listed milestone.
      * ``"StepLR"`` with a scalar ``LR_DECAY_STEP``: builds ``StepLR`` with
        that step size.
      * ``"MultiStepLR"``: builds ``MultiStepLR``. A scalar ``LR_DECAY_STEP`` is
        treated as a single milestone.

    Args:
        optimizer (Optimizer): optimizer whose learning rate is scheduled.
        cfg: config exposing ``SCHEDULER``, ``LR_DECAY_STEP`` and
            ``LR_DECAY_GAMMA`` as attributes, plus a ``get(key, default)`` method.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        NotImplementedError: if ``cfg.SCHEDULER`` is not supported.
    """
    scheduler = cfg.SCHEDULER
    lr_decay_step = cfg.get("LR_DECAY_STEP", -1)
    milestone_types = (list, tuple)

    if scheduler == "StepLR" and isinstance(lr_decay_step, milestone_types):
        # A sequence of steps under "StepLR" means "decay at each listed epoch".
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=list(lr_decay_step),
            gamma=float(cfg.LR_DECAY_GAMMA),
        )

    if scheduler == "StepLR":
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            cfg.LR_DECAY_STEP,
            gamma=float(cfg.LR_DECAY_GAMMA),
        )

    if scheduler == "MultiStepLR":
        milestones = cfg.LR_DECAY_STEP
        # MultiStepLR builds a Counter from its milestones, which fails on a
        # bare int, so a single scalar milestone is wrapped in a list.
        if not isinstance(milestones, milestone_types):
            milestones = [milestones]
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=list(milestones),
            gamma=float(cfg.LR_DECAY_GAMMA),
        )

    raise NotImplementedError(f"{scheduler} not yet be implemented")


def clip_gradient(optimizer, max_norm, norm_type):
    """Clip gradients in place to guard against exploding gradients.

    ``clip_grad_norm_`` is called separately for every parameter tensor in every
    param group of ``optimizer``. The ``max_norm`` bound therefore applies to each
    tensor's own gradient norm, not to one global norm over all parameters.

    Args:
        optimizer (torch.optim.Optimizer): optimizer whose parameters' gradients
            are clipped.
        max_norm (float): maximum allowed norm of each tensor's gradient.
        norm_type (float): order of the p-norm used to measure it.
    """
    every_param = (
        tensor for group in optimizer.param_groups for tensor in group["params"]
    )
    for tensor in every_param:
        clip_grad.clip_grad_norm_(tensor, max_norm, norm_type)


def setup_seed(seed, conv_repeatable=True):
    """Seed every random number generator the training code relies on.

    Seeds Python ``random``, NumPy, the PyTorch CPU generator and the CUDA
    generators, and exports ``PYTHONHASHSEED``. Python reads that variable only
    when an interpreter starts, so it affects processes launched afterwards, not
    the current one.

    @NOTE: Spawned child processes do not inherit a seed set manually in the
    parent process, so call this again inside the ``main_worker`` function.

    Args:
        seed (int): seed value.
        conv_repeatable (bool, optional): if True, turn off cuDNN benchmarking and
            request deterministic cuDNN kernels so convolutions are repeatable.
            If False, the cuDNN flags are left as they are. Defaults to True.
    """
    logger.warning(f"setup random seed : {seed}")
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    if not conv_repeatable:
        return
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def worker_init_fn(worker_id):
    """Re-seed NumPy and Python ``random`` inside a DataLoader worker process.

    @NOTE: https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading
    PyTorch seeds its own RNG in each worker with ``base_seed + worker_id``.
    ``base_seed`` is drawn by the main process from its RNG (which consumes RNG
    state) or from a user-supplied generator. Other libraries are not re-seeded,
    so their seeds may be duplicated across workers, and every worker would then
    return identical random numbers.

    Args:
        worker_id (int): worker id
    """
    # torch.initial_seed() already differs per worker (base_seed + worker_id,
    # per the note above). worker_id is added once more on top, which still
    # leaves the derived seeds distinct per worker.
    # NOTE(review): np.random.seed only accepts values in [0, 2**32). This
    # assumes CONST.INT_MAX keeps the result in that range; confirm.
    derived_seed = (int(torch.initial_seed()) + worker_id) % CONST.INT_MAX

    # Any other source of randomness used by the dataset or augmentation
    # pipeline has to be re-seeded here as well, one library at a time.
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


def param_count(net):
    """Return the total number of parameters in ``net``, expressed in millions."""
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    return total / 1e6


def param_size(net):
    """Return the approximate parameter memory of ``net`` in MiB.

    Every parameter is counted as 4 bytes (32-bit float), whatever its real dtype.
    """
    bytes_per_param = 4  # deliberate simplification: all parameters treated as float32
    n_params = sum(map(torch.numel, net.parameters()))
    return n_params * bytes_per_param / (1024 * 1024)


def load_weights(moudle: nn.Module, pretrained=None, strict=True):
    """Load a pretrained checkpoint into ``moudle`` in place.

    Two checkpoint layouts are accepted:
      * a bare ``OrderedDict`` state dict, loaded as-is;
      * a ``dict`` holding a ``"state_dict"`` entry. A leading ``"module."`` on
        its keys (as produced by saving an ``nn.DataParallel``-wrapped model) is
        stripped before loading.

    Args:
        moudle (nn.Module): module that receives the weights.
        pretrained (str, optional): checkpoint file path. ``None`` or ``""`` skips
            loading and only logs a warning.
        strict (bool, optional): forwarded to ``load_state_dict``. When False,
            missing/unexpected keys are logged instead of silently ignored.
            Defaults to True.

    Raises:
        FileNotFoundError: if ``pretrained`` is given but is not an existing file.
        RuntimeError: if the checkpoint matches neither layout above. Also raised
            by ``load_state_dict`` on key mismatches when ``strict`` is True.
    """
    module_name = type(moudle).__name__
    if pretrained == "" or pretrained is None:
        logger.warning(f"=> No pretrained for {module_name}")
        return
    if not os.path.isfile(pretrained):
        logger.error(f"=> No {module_name} checkpoints file found in {pretrained}")
        raise FileNotFoundError(f"No checkpoint file found at {pretrained}")

    logger.info(f"=> Loading {module_name} pretrained model from: {pretrained}")
    # NOTE: torch.load unpickles the file. Only load checkpoints from trusted sources.
    checkpoint = torch.load(pretrained, map_location=torch.device("cpu"))

    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        prefix = "module."  # added by nn.DataParallel / DistributedDataParallel wrappers
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in checkpoint["state_dict"].items()
        )
    else:
        logger.error(f"=> No state_dict found in checkpoint file {pretrained}")
        raise RuntimeError(f"No state_dict found in checkpoint file {pretrained}")

    result = moudle.load_state_dict(state_dict, strict=strict)
    if not strict:
        # With strict=False, mismatched keys are otherwise dropped silently.
        missing = getattr(result, "missing_keys", None)
        unexpected = getattr(result, "unexpected_keys", None)
        if missing:
            logger.warning(f"=> {module_name} missing keys: {missing}")
        if unexpected:
            logger.warning(f"=> {module_name} unexpected keys: {unexpected}")
    logger.info("=> Loading SUCCEEDED")
